In [ ]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import glob
import matplotlib.image as mpimg
In [ ]:
# ----------------------------------------------------------
# Device
# ----------------------------------------------------------
# Prefer Apple-Silicon MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the documented, portable check;
# it returns False (rather than failing) on builds without MPS support.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f'Using Device {device}')
Using Device mps
In [ ]:
# ----------------------------------------------------------
# Hyperparameters (Variant of SAGAN without SN + TTUR)
# ----------------------------------------------------------
EPOCHS = 550                 # total training epochs
BATCH_SIZE = 128             # samples per mini-batch
IMAGE_SIZE = 32              # CIFAR-10 native resolution
CHANNELS_IMG = 3             # RGB
LATENT_DIM = 128             # generator noise-vector size
EMBED_DIM = 50               # class-label embedding size (conditional GAN)
GEN_LR = 1e-4                # TTUR: generator learning rate ...
DISC_LR = 4e-4               # ... discriminator runs 4x faster
BETA1, BETA2 = 0.0, 0.9      # Adam betas used in the SAGAN paper
CHECKPOINT_EVERY = 20        # epochs between checkpoint + sample + metrics
AUTOMOBILE_CLASS_IDX = 1     # CIFAR-10 label index for "automobile"
In [ ]:
# ----------------------------------------------------------
# Self-Attention Block
# ----------------------------------------------------------
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over spatial positions.

    Query/key project to ``in_channels // 8`` channels; the attended
    value map is blended back into the input through a learnable scalar
    ``gamma`` initialised to zero, so the block starts as the identity.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, height, width = x.size()
        n_pos = height * width
        # Flatten the spatial grid into a sequence of n_pos positions.
        q = self.query(x).view(batch, -1, n_pos)                # (b, c//8, n)
        k = self.key(x).view(batch, -1, n_pos)                  # (b, c//8, n)
        weights = torch.softmax(torch.bmm(q.transpose(1, 2), k), dim=-1)  # (b, n, n)
        v = self.value(x).view(batch, channels, n_pos)          # (b, c, n)
        attended = torch.bmm(v, weights.transpose(1, 2))        # (b, c, n)
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x
In [ ]:
# ----------------------------------------------------------
# CIFAR-10 Data Loading
# ----------------------------------------------------------
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # map to [-1, 1] for tanh output
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Filter on the stored label list instead of iterating the dataset:
# enumerate(trainset) would decode + transform all 50k images just to
# read their labels, whereas .targets is a plain list of ints.
automobile_indices = [i for i, label in enumerate(trainset.targets) if label == AUTOMOBILE_CLASS_IDX]
automobile_dataset = Subset(trainset, automobile_indices)
trainloader = DataLoader(automobile_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
In [ ]:
# ----------------------------------------------------------
# Generator (SAGAN-style, no SN)
# ----------------------------------------------------------
class Generator(nn.Module):
    """Conditional SAGAN-style generator (no spectral norm).

    A latent vector concatenated with a class embedding is projected to
    a 4x4x512 feature map, then upsampled 4x4 -> 8x8 -> 16x16 -> 32x32
    with transposed convolutions; self-attention is applied on the 8x8
    feature map. Output is tanh-scaled to [-1, 1].
    """

    def __init__(self, latent_dim, embed_dim, num_classes=10):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        self.init_fc = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 4*4*512),
            nn.BatchNorm1d(4*4*512),
            nn.ReLU(True)
        )
        self.conv_blocks = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),           # 4x4  -> 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            SelfAttention(256),                              # attention on the 8x8 map
            nn.ConvTranspose2d(256, 128, 4, 2, 1),           # 8x8  -> 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, CHANNELS_IMG, 4, 2, 1),  # 16x16 -> 32x32
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate a batch of images from latent codes `z` (b, latent_dim)
        and integer class `labels` (b,)."""
        class_vec = self.label_emb(labels)
        latent = torch.cat([z, class_vec], dim=1)
        feat = self.init_fc(latent)
        return self.conv_blocks(feat.view(-1, 512, 4, 4))
In [ ]:
# ----------------------------------------------------------
# Discriminator (SAGAN-style, no SN)
# ----------------------------------------------------------
class Discriminator(nn.Module):
    """Conditional SAGAN-style discriminator (no spectral norm).

    Downsamples 32x32 -> 16x16 -> 8x8 -> 4x4 -> 2x2 with strided
    convolutions; self-attention is applied on the 8x8 feature map.
    The flattened features are concatenated with a class embedding and
    mapped to a single realness logit (pair with BCEWithLogitsLoss).
    """

    def __init__(self, embed_dim, num_classes=10):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(CHANNELS_IMG, 64, 4, 2, 1),    # 32x32 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),             # 16x16 -> 8x8
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(128),                      # attention on the 8x8 map
            nn.Conv2d(128, 256, 4, 2, 1),            # 8x8  -> 4x4
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),            # 4x4  -> 2x2
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.fc = nn.Linear(512*2*2 + embed_dim, 1)

    def forward(self, x, labels):
        """Return an unnormalised realness logit for images `x` conditioned
        on integer class `labels`."""
        feats = self.conv_blocks(x).flatten(1)
        class_vec = self.label_emb(labels)
        return self.fc(torch.cat([feats, class_vec], dim=1))
In [ ]:
# ----------------------------------------------------------
# Initialize Model, Loss, Optimizers (TTUR)
# ----------------------------------------------------------
# TTUR: the discriminator trains with a 4x larger learning rate than
# the generator (4e-4 vs 1e-4), with SAGAN's Adam betas (0.0, 0.9).
gen = Generator(LATENT_DIM, EMBED_DIM).to(device)
disc = Discriminator(EMBED_DIM).to(device)
criterion = nn.BCEWithLogitsLoss()  # applied to raw logits from the discriminator

opt_gen = optim.Adam(gen.parameters(), lr=GEN_LR, betas=(BETA1, BETA2))
opt_disc = optim.Adam(disc.parameters(), lr=DISC_LR, betas=(BETA1, BETA2))

checkpoint_path = "adl_part_3.pt"  # single file holding all resumable state
start_epoch = 1                    # replaced by checkpoint epoch + 1 when resuming
In [ ]:
# ----------------------------------------------------------
# Check for Existing Checkpoint
# ----------------------------------------------------------
# Resume training in place: restore both networks and both optimizers
# (Adam moment buffers included) and continue from the next epoch.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable only because this file is written by this very
# notebook; do not load checkpoints from untrusted sources this way.
if os.path.exists(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location=device)
    gen.load_state_dict(ckpt["gen_state_dict"])
    disc.load_state_dict(ckpt["disc_state_dict"])
    opt_gen.load_state_dict(ckpt["opt_gen_state_dict"])
    opt_disc.load_state_dict(ckpt["opt_disc_state_dict"])
    start_epoch = ckpt["epoch"] + 1
In [ ]:
# ----------------------------------------------------------
# Utility: Generate & Show Samples
# ----------------------------------------------------------
def generate_and_show_samples(epoch):
    """Generate 10 automobile samples from `gen`, display them, and save
    the figure to task3/automobile_gan_losses_<epoch>.png.

    Temporarily switches the generator to eval mode (so BatchNorm uses
    its running statistics) and restores train mode before returning.
    """
    gen.eval()
    with torch.no_grad():
        z = torch.randn(10, LATENT_DIM, device=device)
        labels = torch.full((10,), AUTOMOBILE_CLASS_IDX, dtype=torch.long, device=device)
        samples = gen(z, labels).cpu()
    samples = (samples + 1) / 2.0  # map tanh output [-1, 1] -> [0, 1] for imshow
    fig, axes = plt.subplots(1, 10, figsize=(22, 2.4))
    for i in range(10):
        img = samples[i].permute(1, 2, 0).numpy()  # CHW -> HWC
        axes[i].imshow(img)
        axes[i].axis('off')
    plt.suptitle(f"Epoch {epoch}: SAGAN-Variant Samples (Automobile)", fontsize=14)
    # savefig raises FileNotFoundError if the output directory is missing.
    os.makedirs('task3', exist_ok=True)
    plt.savefig(f'task3/automobile_gan_losses_{epoch}.png')
    plt.show()
    gen.train()
In [ ]:
# ----------------------------------------------------------
# Compute IS & FID
# ----------------------------------------------------------
def compute_is_fid(generator, loader, n_samples=2000):
    """Compute Inception Score (on fakes) and FID (reals vs fakes).

    Images are mapped from the [-1, 1] training range to uint8 [0, 255]
    as torchmetrics expects. The final real and fake batches are trimmed
    so that exactly `n_samples` images enter each side of the FID —
    previously both loops could overshoot by up to one batch, leaving
    the comparison unbalanced.

    Returns:
        (inception_score_mean, fid) as Python floats.
    """
    is_metric = InceptionScore().to("cpu")
    fid_metric = FrechetInceptionDistance().to("cpu")
    generator.eval()

    real_count = 0
    for real_imgs, _ in loader:
        real_imgs = real_imgs.to(device)
        # Trim the last batch so exactly n_samples reals are counted.
        real_imgs = real_imgs[: n_samples - real_count]
        real_imgs_uint8 = (((real_imgs * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
        fid_metric.update(real_imgs_uint8, real=True)
        real_count += real_imgs.size(0)
        if real_count >= n_samples:
            break

    fake_count = 0
    while fake_count < n_samples:
        z = torch.randn(BATCH_SIZE, LATENT_DIM, device=device)
        labels = torch.randint(0, 10, (BATCH_SIZE,), dtype=torch.long, device=device)
        with torch.no_grad():
            fake_out = generator(z, labels)
        # Trim the last batch so exactly n_samples fakes are counted.
        fake_out = fake_out[: n_samples - fake_count]
        fake_out_uint8 = (((fake_out * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
        is_metric.update(fake_out_uint8)
        fid_metric.update(fake_out_uint8, real=False)
        fake_count += fake_out.size(0)

    inception_score = is_metric.compute()  # (mean, std)
    fid_score = fid_metric.compute()
    generator.train()
    return inception_score[0].item(), fid_score.item()
In [ ]:
# ----------------------------------------------------------
# Training Loop
# ----------------------------------------------------------
for epoch in range(start_epoch, EPOCHS + 1):
    for _, (real, labels) in enumerate(trainloader):
        real, labels = real.to(device), labels.to(device)
        bsz = real.size(0)

        # --------------------
        # Train Discriminator
        # --------------------
        # Minimise BCE toward 1 on reals and 0 on fakes; fakes are
        # detached so this step does not backprop into the generator.
        disc.zero_grad()
        noise = torch.randn(bsz, LATENT_DIM, device=device)
        rand_labels = torch.randint(0, 10, (bsz,), dtype=torch.long, device=device)

        pred_real = disc(real, labels)
        loss_real = criterion(pred_real, torch.ones_like(pred_real))

        fake = gen(noise, rand_labels)
        pred_fake = disc(fake.detach(), rand_labels)
        loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

        lossD = loss_real + loss_fake
        lossD.backward()
        opt_disc.step()

        # ----------------
        # Train Generator
        # ----------------
        # Non-saturating loss: push D's logits toward the "real" label
        # on the same fakes (this time NOT detached). Note D has just
        # been updated above, so the generator sees the newer critic.
        gen.zero_grad()
        pred_gen = disc(fake, rand_labels)
        lossG = criterion(pred_gen, torch.ones_like(pred_gen))
        lossG.backward()
        opt_gen.step()

    # Only the losses of the epoch's final batch are logged.
    print(f"[Epoch {epoch}/{EPOCHS}]  LossD: {lossD.item():.4f}  LossG: {lossG.item():.4f}")

    # Periodically checkpoint all resumable state, show samples, and
    # evaluate IS/FID (evaluation is expensive, hence the interval).
    if epoch % CHECKPOINT_EVERY == 0:
        save_data = {
            "epoch": epoch,
            "gen_state_dict": gen.state_dict(),
            "disc_state_dict": disc.state_dict(),
            "opt_gen_state_dict": opt_gen.state_dict(),
            "opt_disc_state_dict": opt_disc.state_dict()
        }
        torch.save(save_data, checkpoint_path)
        print(f"[epoch={epoch}]Checkpoint saved: {checkpoint_path}")
        generate_and_show_samples(epoch)
        is_val, fid_val = compute_is_fid(gen, trainloader)
        print(f"==> Epoch {epoch}: Inception Score = {is_val:.4f}, FID = {fid_val:.4f}")

print("Training complete!")
[Epoch 501/550]  LossD: 0.0084  LossG: 30.5150
[Epoch 502/550]  LossD: 0.0000  LossG: 26.5804
[Epoch 503/550]  LossD: 0.0515  LossG: 17.1791
[Epoch 504/550]  LossD: 0.0021  LossG: 41.7324
[Epoch 505/550]  LossD: 0.0000  LossG: 28.7694
[Epoch 506/550]  LossD: 0.0006  LossG: 33.8097
[Epoch 507/550]  LossD: 0.0705  LossG: 35.9834
[Epoch 508/550]  LossD: 0.0033  LossG: 22.5316
[Epoch 509/550]  LossD: 0.0017  LossG: 41.0436
[Epoch 510/550]  LossD: 0.0824  LossG: 29.4740
[Epoch 511/550]  LossD: 1.8864  LossG: 20.3700
[Epoch 512/550]  LossD: 0.0001  LossG: 55.0313
[Epoch 513/550]  LossD: 0.0020  LossG: 50.0931
[Epoch 514/550]  LossD: 0.0071  LossG: 29.5711
[Epoch 515/550]  LossD: 0.0000  LossG: 47.5927
[Epoch 516/550]  LossD: 0.0000  LossG: 37.7736
[Epoch 517/550]  LossD: 0.0172  LossG: 12.9776
[Epoch 518/550]  LossD: 0.0001  LossG: 65.8510
[Epoch 519/550]  LossD: 0.0016  LossG: 16.5696
[Epoch 520/550]  LossD: 0.0647  LossG: 55.0012
[epoch=520]Checkpoint saved: adl_part_3.pt
No description has been provided for this image
/Users/shivamsahil/Downloads/bits/assignments/venv/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `InceptionScore` will save all extracted features in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)  # noqa: B028
==> Epoch 520: Inception Score = 1.3875, FID = 315.2394
[Epoch 521/550]  LossD: 0.0049  LossG: 35.1182
[Epoch 522/550]  LossD: 0.0000  LossG: 54.1443
[Epoch 523/550]  LossD: 0.0015  LossG: 37.0467
[Epoch 524/550]  LossD: 0.0001  LossG: 53.0654
[Epoch 525/550]  LossD: 0.0000  LossG: 50.4639
[Epoch 526/550]  LossD: 0.0003  LossG: 49.2188
[Epoch 527/550]  LossD: 0.0005  LossG: 43.3066
[Epoch 528/550]  LossD: 0.0833  LossG: 23.3325
[Epoch 529/550]  LossD: 0.0024  LossG: 56.1662
[Epoch 530/550]  LossD: 0.0004  LossG: 43.3856
[Epoch 531/550]  LossD: 0.0001  LossG: 69.2146
[Epoch 532/550]  LossD: 0.0122  LossG: 28.4343
[Epoch 533/550]  LossD: 0.0000  LossG: 54.0010
[Epoch 534/550]  LossD: 0.4353  LossG: 45.7758
[Epoch 535/550]  LossD: 0.0683  LossG: 62.5474
[Epoch 536/550]  LossD: 0.7499  LossG: 92.6146
[Epoch 537/550]  LossD: 0.0002  LossG: 46.4222
[Epoch 538/550]  LossD: 0.0801  LossG: 44.5221
[Epoch 539/550]  LossD: 0.0000  LossG: 26.9270
[Epoch 540/550]  LossD: 0.0000  LossG: 57.2657
[epoch=540]Checkpoint saved: adl_part_3.pt
No description has been provided for this image
==> Epoch 540: Inception Score = 1.7105, FID = 339.5601
[Epoch 541/550]  LossD: 0.0017  LossG: 32.2723
[Epoch 542/550]  LossD: 0.0193  LossG: 47.9374
[Epoch 543/550]  LossD: 0.0074  LossG: 33.3457
[Epoch 544/550]  LossD: 0.0010  LossG: 52.6690
[Epoch 545/550]  LossD: 0.0003  LossG: 24.1623
[Epoch 546/550]  LossD: 0.0000  LossG: 58.0807
[Epoch 547/550]  LossD: 0.0000  LossG: 35.8908
[Epoch 548/550]  LossD: 0.0000  LossG: 38.9137
[Epoch 549/550]  LossD: 0.0001  LossG: 61.0286
[Epoch 550/550]  LossD: 0.0196  LossG: 30.3804
Training complete!
In [ ]:
directory = r'task3'

def extract_epoch(filename):
    """Sort key for saved sample grids: the integer epoch embedded in
    'automobile_gan_losses_<epoch>.png'. Files that do not match the
    pattern sort last (inf)."""
    name = os.path.basename(filename)
    _, marker, tail = name.partition('automobile_gan_losses_')
    if not marker:
        return float('inf')
    try:
        return int(tail.split('.')[0])
    except ValueError:
        return float('inf')

# Collect every saved sample grid and order numerically by epoch.
png_files = sorted(glob.glob(os.path.join(directory, '*.png')), key=extract_epoch)

# Check if any PNG files are found
if not png_files:
    print("No PNG files found in the directory:", directory)
else:
    n = len(png_files)

    # One row per saved epoch-sample image; scale the figure height so
    # each row keeps roughly its original 22 x 2.4 proportions.
    fig, axs = plt.subplots(n, 1, figsize=(22, 2.4 * n))

    # If only one image, wrap axs into a list for consistency
    # (plt.subplots returns a bare Axes, not an array, when n == 1).
    if n == 1:
        axs = [axs]

    # Try to maximise the window; the attribute used depends on the GUI
    # backend (Tk vs Qt), and inline backends have neither — hence the
    # nested try/except that silently falls through.
    mng = plt.get_current_fig_manager()
    try:
        mng.window.state('zoomed')       # TkAgg
    except AttributeError:
        try:
            mng.window.showMaximized()   # QtAgg
        except Exception:
            pass

    # Loop through each file and display the image
    for ax, file in zip(axs, png_files):
        img = mpimg.imread(file)
        ax.imshow(img, aspect='auto')
        ax.axis('off')
        ax.set_title(os.path.basename(file), fontsize=14)

    plt.tight_layout()
    plt.show()
No description has been provided for this image
In [ ]:
# Install necessary packages
# (Colab-only cell; the TeX packages would only be needed for --to pdf.)
!apt-get install texlive texlive-xetex texlive-latex-extra pandoc
!pip install pypandoc

# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

# Copy the notebook to the current directory
!cp 'drive/My Drive/Colab Notebooks/Assignment2_Group75_Task3.ipynb' ./

# Convert the notebook to HTML while keeping the code and output
!jupyter nbconvert --to html "Assignment2_Group75_Task3.ipynb"


# Download the generated HTML file
from google.colab import files
files.download('Assignment2_Group75_Task3.html')
done.